{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " # Patch Time Series Transformer for Transfer Learning across datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " The `PatchTST` model was proposed in A Time Series is Worth [64 Words: Long-term Forecasting with Transformers](https://arxiv.org/abs/2211.14730) by Yuqi Nie, Nam H. Nguyen, Phanwadee Sinthong, Jayant Kalagnanam.\n",
    "\n",
    " `PatchTST` is a time-series foundation modeling approach based on the MLP-Mixer\n",
    " architecture.\n",
    "\n",
    " In this notebook, we will demonstrate the tranfer learning capability of the `PatchTST` model.\n",
    " We will pretrain the model for a forecasting task on a `source` dataset. Then, we will use the\n",
    " pretrained model for a zero-shot forecasting on a `target` dataset. The zero-shot forecasting\n",
    " performance will denote the `test` performance of the model in the `target` domain, without any\n",
    " training on the target domain. Subsequently, we will do linear probing and (then) finetuning of\n",
    " the pretrained model on the `train` part of the target data, and will validate the forecasting\n",
    " performance on the `test` part of the target data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard\n",
    "import os\n",
    "import random\n",
    "\n",
    "# Third Party\n",
    "from transformers import (\n",
    "    EarlyStoppingCallback,\n",
    "    PatchTSTConfig,\n",
    "    PatchTSTForPrediction,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# First Party\n",
    "from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
    "from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n",
    "from tsfm_public.toolkit.util import select_by_index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Set seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2023\n",
    "torch.manual_seed(SEED)\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Load and prepare datasets\n",
    "\n",
    " In the next cell, please adjust the following parameters to suit your application:\n",
    " - PRETRAIN_AGAIN: Set this to `True` if you want to perform pretraining again. Note that this might take some time depending on the GPU availability. Otherwise, the already pretrained model will be used.\n",
    " - dataset_path: path to local .csv file, or web address to a csv file for the data of interest. Data is loaded with pandas, so anything supported by\n",
    "   `pd.read_csv` is supported: (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html).\n",
    " - timestamp_column: column name containing timestamp information, use None if there is no such column\n",
    " - id_columns: List of column names specifying the IDs of different time series. If no ID column exists, use []\n",
    " - forecast_columns: List of columns to be modeled\n",
    " - context_length: The amount of historical data used as input to the model. Windows of the input time series data with length equal to\n",
    "   context_length will be extracted from the input dataframe. In the case of a multi-time series dataset, the context windows will be created\n",
    "   so that they are contained within a single time series (i.e., a single ID).\n",
    " - forecast_horizon: Number of time stamps to forecast in future.\n",
    " - train_start_index, train_end_index: the start and end indices in the loaded data which delineate the training data.\n",
    " - valid_start_index, eval_end_index: the start and end indices in the loaded data which delineate the validation data.\n",
    " - test_start_index, eval_end_index: the start and end indices in the loaded data which delineate the test data.\n",
    " - patch_length: The patch length for the `PatchTSMixer` model. Recommended to have a value so that `context_length` is divisible by it.\n",
    " - num_workers: Number of dataloder workers in pytorch dataloader.\n",
    " - batch_size: Batch size.\n",
    " The data is first loaded into a Pandas dataframe and split into training, validation, and test parts. Then the pandas dataframes are converted\n",
    " to the appropriate torch dataset needed for training."
   ]
  },
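  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " The context/forecast windowing described above can be sketched in NumPy. This is an illustration of the idea rather than the\n",
    " `ForecastDFDataset` implementation: the helper `make_windows` is hypothetical, and the sketch assumes a single time series\n",
    " (no ID columns) and a stride of one time step between consecutive windows.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def make_windows(values, context_length, forecast_horizon):\n",
    "    # values: array of shape (num_timesteps, num_channels).\n",
    "    # Each sample pairs a context window with the forecast_horizon steps that follow it.\n",
    "    total = context_length + forecast_horizon\n",
    "    num_windows = len(values) - total + 1\n",
    "    if num_windows <= 0:\n",
    "        raise ValueError('series is shorter than context_length + forecast_horizon')\n",
    "    past = np.stack([values[i : i + context_length] for i in range(num_windows)])\n",
    "    future = np.stack([values[i + context_length : i + total] for i in range(num_windows)])\n",
    "    return past, future\n",
    "\n",
    "\n",
    "series = np.arange(20, dtype=float).reshape(10, 2)  # 10 time steps, 2 channels\n",
    "past, future = make_windows(series, context_length=4, forecast_horizon=2)\n",
    "print(past.shape, future.shape)  # (5, 4, 2) (5, 2, 2)\n",
    "```\n",
    "\n",
    " With a stride of one, a series of `T` rows yields `T - context_length - forecast_horizon + 1` samples."
   ]
  },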
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "PRETRAIN_AGAIN = True\n",
    "# Download ECL data from https://github.com/zhouhaoyi/Informer2020\n",
    "dataset_path = \"~/Downloads/ECL.csv\"\n",
    "timestamp_column = \"date\"\n",
    "id_columns = []\n",
    "\n",
    "context_length = 512\n",
    "forecast_horizon = 96\n",
    "patch_length = 16\n",
    "num_workers = 1\n",
    "batch_size = 16  # 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    data = pd.read_csv(\n",
    "        dataset_path,\n",
    "        parse_dates=[timestamp_column],\n",
    "    )\n",
    "    forecast_columns = list(data.columns[1:])\n",
    "\n",
    "    # get split\n",
    "    num_train = int(len(data) * 0.7)\n",
    "    num_test = int(len(data) * 0.2)\n",
    "    num_valid = len(data) - num_train - num_test\n",
    "    border1s = [\n",
    "        0,\n",
    "        num_train - context_length,\n",
    "        len(data) - num_test - context_length,\n",
    "    ]\n",
    "    border2s = [num_train, num_train + num_valid, len(data)]\n",
    "\n",
    "    train_start_index = border1s[0]  # None indicates beginning of dataset\n",
    "    train_end_index = border2s[0]\n",
    "\n",
    "    # we shift the start of the evaluation period back by context length so that\n",
    "    # the first evaluation timestamp is immediately following the training data\n",
    "    valid_start_index = border1s[1]\n",
    "    valid_end_index = border2s[1]\n",
    "\n",
    "    test_start_index = border1s[2]\n",
    "    test_end_index = border2s[2]\n",
    "\n",
    "    train_data = select_by_index(\n",
    "        data,\n",
    "        id_columns=id_columns,\n",
    "        start_index=train_start_index,\n",
    "        end_index=train_end_index,\n",
    "    )\n",
    "    valid_data = select_by_index(\n",
    "        data,\n",
    "        id_columns=id_columns,\n",
    "        start_index=valid_start_index,\n",
    "        end_index=valid_end_index,\n",
    "    )\n",
    "    test_data = select_by_index(\n",
    "        data,\n",
    "        id_columns=id_columns,\n",
    "        start_index=test_start_index,\n",
    "        end_index=test_end_index,\n",
    "    )\n",
    "\n",
    "    tsp = TimeSeriesPreprocessor(\n",
    "        timestamp_column=timestamp_column,\n",
    "        id_columns=id_columns,\n",
    "        input_columns=forecast_columns,\n",
    "        output_columns=forecast_columns,\n",
    "        scaling=True,\n",
    "    )\n",
    "    tsp.train(train_data)"
   ]
  },
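  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " The border arithmetic in the cell above can be isolated into a small function and checked on round numbers. `split_borders`\n",
    " is a hypothetical helper (not part of `tsfm_public`) that mirrors the 70% train / 20% test / remaining-rows validation split\n",
    " and the `context_length` look-back used above.\n",
    "\n",
    "```python\n",
    "def split_borders(num_rows, context_length, train_frac=0.7, test_frac=0.2):\n",
    "    # 70% train, 20% test, and the remaining rows for validation, as in the cell above.\n",
    "    num_train = int(num_rows * train_frac)\n",
    "    num_test = int(num_rows * test_frac)\n",
    "    num_valid = num_rows - num_train - num_test\n",
    "    # Validation and test start context_length rows early so that their first\n",
    "    # forecast target immediately follows the preceding split.\n",
    "    starts = [0, num_train - context_length, num_rows - num_test - context_length]\n",
    "    ends = [num_train, num_train + num_valid, num_rows]\n",
    "    return list(zip(starts, ends))\n",
    "\n",
    "\n",
    "print(split_borders(1000, context_length=100))  # [(0, 700), (600, 800), (700, 1000)]\n",
    "```\n",
    "\n",
    " In this example the validation window overlaps the last 100 training rows, but those rows only serve as model input;\n",
    " the forecast targets of the three splits do not overlap."
   ]
  },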
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    train_dataset = ForecastDFDataset(\n",
    "        tsp.preprocess(train_data),\n",
    "        id_columns=id_columns,\n",
    "        timestamp_column=\"date\",\n",
    "        input_columns=forecast_columns,\n",
    "        output_columns=forecast_columns,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_horizon,\n",
    "    )\n",
    "    valid_dataset = ForecastDFDataset(\n",
    "        tsp.preprocess(valid_data),\n",
    "        id_columns=id_columns,\n",
    "        timestamp_column=\"date\",\n",
    "        input_columns=forecast_columns,\n",
    "        output_columns=forecast_columns,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_horizon,\n",
    "    )\n",
    "    test_dataset = ForecastDFDataset(\n",
    "        tsp.preprocess(test_data),\n",
    "        id_columns=id_columns,\n",
    "        timestamp_column=\"date\",\n",
    "        input_columns=forecast_columns,\n",
    "        output_columns=forecast_columns,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_horizon,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Configure the PatchTST model\n",
    "\n",
    " The settings below control the different components in the PatchTST model.\n",
    "  - num_input_channels: the number of input channels (or dimensions) in the time series data. This is\n",
    "    automatically set to the number for forecast columns.\n",
    "  - context_length: As described above, the amount of historical data used as input to the model.\n",
    "  - patch_length: The length of the patches extracted from the context window (of length `context_length`).\n",
    "  - patch_stride: The stride used when extracting patches from the context window.\n",
    "  - mask_ratio: The fraction of input patches that are completely masked for the purpose of pretraining the model.\n",
    "  - d_model: Dimension of the transformer layers.\n",
    "  - encoder_attention_heads: The number of attention heads for each attention layer in the Transformer encoder.\n",
    "  - encoder_layers: The number of encoder layers.\n",
    "  - encoder_ffn_dim: Dimension of the intermediate (often referred to as feed-forward) layer in the encoder.\n",
    "  - dropout: Dropout probability for all fully connected layers in the encoder.\n",
    "  - head_dropout: Dropout probability used in the head of the model.\n",
    "  - pooling_type: Pooling of the embedding. `\"mean\"`, `\"max\"` and `None` are supported.\n",
    "  - channel_attention: Activate channel attention block in the Transformer to allow channels to attend each other.\n",
    "  - scaling: Whether to scale the input targets via \"mean\" scaler, \"std\" scaler or no scaler if `None`. If `True`, the\n",
    "    scaler is set to `\"mean\"`.\n",
    "  - loss: The loss function for the model corresponding to the `distribution_output` head. For parametric\n",
    "    distributions it is the negative log likelihood (`\"nll\"`) and for point estimates it is the mean squared\n",
    "    error `\"mse\"`.\n",
    "  - pre_norm: Normalization is applied before self-attention if pre_norm is set to `True`. Otherwise, normalization is\n",
    "    applied after residual block.\n",
    "  - norm: Normalization at each Transformer layer. Can be `\"BatchNorm\"` or `\"LayerNorm\"`.\n",
    "\n",
    " We recommend that you only adjust the values in the next cell."
   ]
  },
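  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " To make `patch_length` and `patch_stride` concrete: with the values used in this notebook (`context_length=512`, `patch_length=16`,\n",
    " `patch_stride=16`), the context window is cut into non-overlapping patches. The sketch below counts and extracts them for a single\n",
    " channel with NumPy's `sliding_window_view` (NumPy 1.20 or later). It illustrates the arithmetic only; the helper names are\n",
    " hypothetical and this is not how `PatchTSTForPrediction` is implemented internally.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from numpy.lib.stride_tricks import sliding_window_view\n",
    "\n",
    "\n",
    "def num_patches(context_length, patch_length, patch_stride):\n",
    "    # Number of full patches that fit in the context window.\n",
    "    return (context_length - patch_length) // patch_stride + 1\n",
    "\n",
    "\n",
    "def extract_patches(series, patch_length, patch_stride):\n",
    "    # series: 1-D array holding one channel's context window.\n",
    "    windows = sliding_window_view(series, patch_length)  # every contiguous slice of length patch_length\n",
    "    return windows[::patch_stride]  # keep one slice every patch_stride steps\n",
    "\n",
    "\n",
    "context = np.arange(512, dtype=float)\n",
    "patches = extract_patches(context, patch_length=16, patch_stride=16)\n",
    "print(num_patches(512, 16, 16), patches.shape)  # 32 (32, 16)\n",
    "```\n",
    "\n",
    " Each channel's 512-step context therefore becomes a sequence of 32 patch tokens. Halving the stride to 8 would give overlapping\n",
    " patches and `num_patches(512, 16, 8) == 63` tokens, at a higher attention cost."
   ]
  },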
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    config = PatchTSTConfig(\n",
    "        num_input_channels=len(forecast_columns),\n",
    "        context_length=context_length,\n",
    "        patch_length=patch_length,\n",
    "        patch_stride=patch_length,\n",
    "        prediction_length=forecast_horizon,\n",
    "        mask_ratio=0.4,\n",
    "        d_model=128,\n",
    "        encoder_attention_heads=16,\n",
    "        encoder_layers=3,\n",
    "        encoder_ffn_dim=256,\n",
    "        dropout=0.2,\n",
    "        head_dropout=0.2,\n",
    "        pooling_type=None,\n",
    "        channel_attention=False,\n",
    "        scaling=\"std\",\n",
    "        loss=\"mse\",\n",
    "        pre_norm=True,\n",
    "        norm=\"batchnorm\",\n",
    "    )\n",
    "    model = PatchTSTForPrediction(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Train model\n",
    "\n",
    " Trains the PatchTSMixer model based on the direct forecasting strategy."
   ]
  },
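  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " In the direct strategy, the model outputs all `forecast_horizon` future steps in one forward pass, instead of predicting one step\n",
    " at a time and feeding each prediction back in as input (the recursive strategy). The toy NumPy sketch below contrasts the two,\n",
    " with random linear maps standing in for trained models; all names and weights are hypothetical and unrelated to PatchTST's\n",
    " actual prediction head.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "context_length, horizon = 8, 4\n",
    "history = rng.normal(size=context_length)\n",
    "\n",
    "# Direct: one map from the whole context to all horizon steps at once.\n",
    "w_direct = rng.normal(size=(horizon, context_length))\n",
    "\n",
    "\n",
    "def direct_forecast(x):\n",
    "    return w_direct @ x  # shape (horizon,)\n",
    "\n",
    "\n",
    "# Recursive: a one-step map applied horizon times, feeding predictions back in.\n",
    "w_step = rng.normal(size=context_length)\n",
    "\n",
    "\n",
    "def recursive_forecast(x):\n",
    "    window = list(x)\n",
    "    preds = []\n",
    "    for _ in range(horizon):\n",
    "        y = float(np.dot(w_step, window[-context_length:]))\n",
    "        preds.append(y)\n",
    "        window.append(y)  # the prediction becomes input for the next step\n",
    "    return np.array(preds)\n",
    "\n",
    "\n",
    "print(direct_forecast(history).shape, recursive_forecast(history).shape)  # (4,) (4,)\n",
    "```\n",
    "\n",
    " A direct model never consumes its own predictions, so errors do not compound across the horizon; the trade-off is that\n",
    " its output size is tied to the forecast horizon (here `prediction_length=96`)."
   ]
  },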
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='20' max='20' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [20/20 01:16, Epoch 0/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>0.750900</td>\n",
       "      <td>0.655789</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    training_args = TrainingArguments(\n",
    "        output_dir=\"./checkpoint/patchtst/electricity/pretrain/output/\",\n",
    "        overwrite_output_dir=True,\n",
    "        # learning_rate=0.001,\n",
    "        num_train_epochs=100,\n",
    "        do_eval=True,\n",
    "        evaluation_strategy=\"epoch\",\n",
    "        per_device_train_batch_size=batch_size,\n",
    "        per_device_eval_batch_size=batch_size,\n",
    "        dataloader_num_workers=1,\n",
    "        report_to=\"tensorboard\",\n",
    "        save_strategy=\"epoch\",\n",
    "        logging_strategy=\"epoch\",\n",
    "        save_total_limit=3,\n",
    "        logging_dir=\"./checkpoint/patchtst/electricity/pretrain/logs/\",  # Make sure to specify a logging directory\n",
    "        load_best_model_at_end=True,  # Load the best model when training ends\n",
    "        metric_for_best_model=\"eval_loss\",  # Metric to monitor for early stopping\n",
    "        greater_is_better=False,  # For loss\n",
    "        label_names=[\"future_values\"],\n",
    "        max_steps=20,\n",
    "    )\n",
    "\n",
    "    # Create the early stopping callback\n",
    "    early_stopping_callback = EarlyStoppingCallback(\n",
    "        early_stopping_patience=10,  # Number of epochs with no improvement after which to stop\n",
    "        early_stopping_threshold=0.0001,  # Minimum improvement required to consider as improvement\n",
    "    )\n",
    "\n",
    "    # define trainer\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        train_dataset=train_dataset,\n",
    "        eval_dataset=valid_dataset,\n",
    "        callbacks=[early_stopping_callback],\n",
    "    )\n",
    "\n",
    "    # pretrain\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Evaluate model on the test set of the `source` domain\n",
    " Note that this is not the target metric to judge in this task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    results = trainer.evaluate(test_dataset)\n",
    "    print(\"Test result:\")\n",
    "    print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Save model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    save_dir = \"patchtst/electricity/model/pretrain/\"\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    trainer.save_model(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Transfer Learing on ETTh1 data. All evaluations are on the `test` part of the ETTh1 data.\n",
    " Step 1: Directly evaluate the electricity-pretrained model. This is the zero-shot performance.\n",
    " Step 2: Evalute after doing linear probing.\n",
    " Step 3: Evaluate after doing full finetuning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Load ETTH data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"ETTh1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading target dataset: ETTh1\n"
     ]
    }
   ],
   "source": [
    "print(f\"Loading target dataset: {dataset}\")\n",
    "dataset_path = f\"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv\"\n",
    "timestamp_column = \"date\"\n",
    "id_columns = []\n",
    "forecast_columns = [\"HUFL\", \"HULL\", \"MUFL\", \"MULL\", \"LUFL\", \"LULL\", \"OT\"]\n",
    "train_start_index = None  # None indicates beginning of dataset\n",
    "train_end_index = 12 * 30 * 24\n",
    "\n",
    "# we shift the start of the evaluation period back by context length so that\n",
    "# the first evaluation timestamp is immediately following the training data\n",
    "valid_start_index = 12 * 30 * 24 - context_length\n",
    "valid_end_index = 12 * 30 * 24 + 4 * 30 * 24\n",
    "\n",
    "test_start_index = 12 * 30 * 24 + 4 * 30 * 24 - context_length\n",
    "test_end_index = 12 * 30 * 24 + 8 * 30 * 24"
   ]
  },
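  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " The ETTh1 indices above encode a 12/4/4-month train/validation/test split of hourly data, counting a month as 30 days\n",
    " (`30 * 24` rows). The arithmetic is spelled out below as a quick check; `months` is a hypothetical helper, and\n",
    " `context_length=512` matches the value set earlier.\n",
    "\n",
    "```python\n",
    "context_length = 512\n",
    "HOURS_PER_MONTH = 30 * 24  # ETTh1 is hourly; a month is taken as 30 days\n",
    "\n",
    "\n",
    "def months(n):\n",
    "    return n * HOURS_PER_MONTH\n",
    "\n",
    "\n",
    "train_end = months(12)\n",
    "valid_start, valid_end = train_end - context_length, months(12 + 4)\n",
    "test_start, test_end = valid_end - context_length, months(12 + 8)\n",
    "print(train_end, (valid_start, valid_end), (test_start, test_end))\n",
    "# 8640 (8128, 11520) (11008, 14400)\n",
    "```"
   ]
  },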
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TimeSeriesPreprocessor {\n",
       "  \"context_length\": 64,\n",
       "  \"feature_extractor_type\": \"TimeSeriesPreprocessor\",\n",
       "  \"id_columns\": [],\n",
       "  \"input_columns\": [\n",
       "    \"HUFL\",\n",
       "    \"HULL\",\n",
       "    \"MUFL\",\n",
       "    \"MULL\",\n",
       "    \"LUFL\",\n",
       "    \"LULL\",\n",
       "    \"OT\"\n",
       "  ],\n",
       "  \"output_columns\": [\n",
       "    \"HUFL\",\n",
       "    \"HULL\",\n",
       "    \"MUFL\",\n",
       "    \"MULL\",\n",
       "    \"LUFL\",\n",
       "    \"LULL\",\n",
       "    \"OT\"\n",
       "  ],\n",
       "  \"prediction_length\": null,\n",
       "  \"processor_class\": \"TimeSeriesPreprocessor\",\n",
       "  \"scaler_dict\": {\n",
       "    \"0\": {\n",
       "      \"copy\": true,\n",
       "      \"feature_names_in_\": [\n",
       "        \"HUFL\",\n",
       "        \"HULL\",\n",
       "        \"MUFL\",\n",
       "        \"MULL\",\n",
       "        \"LUFL\",\n",
       "        \"LULL\",\n",
       "        \"OT\"\n",
       "      ],\n",
       "      \"mean_\": [\n",
       "        7.937742245659508,\n",
       "        2.0210386567335163,\n",
       "        5.079770601157927,\n",
       "        0.7461858799957015,\n",
       "        2.781762386375555,\n",
       "        0.7884531235540096,\n",
       "        17.1282616982271\n",
       "      ],\n",
       "      \"n_features_in_\": 7,\n",
       "      \"n_samples_seen_\": 8640,\n",
       "      \"scale_\": [\n",
       "        5.812749409143771,\n",
       "        2.0901046504076,\n",
       "        5.518793579036245,\n",
       "        1.9263792741329822,\n",
       "        1.0235226594952194,\n",
       "        0.6302366362251923,\n",
       "        9.176491024944335\n",
       "      ],\n",
       "      \"var_\": [\n",
       "        33.78805569350125,\n",
       "        4.368537449655475,\n",
       "        30.457082568011693,\n",
       "        3.710937107809115,\n",
       "        1.0475986345001667,\n",
       "        0.39719821764044544,\n",
       "        84.20798753088393\n",
       "      ],\n",
       "      \"with_mean\": true,\n",
       "      \"with_std\": true\n",
       "    }\n",
       "  },\n",
       "  \"scaling\": true,\n",
       "  \"time_series_task\": \"forecasting\",\n",
       "  \"timestamp_column\": \"date\"\n",
       "}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv(\n",
    "    dataset_path,\n",
    "    parse_dates=[timestamp_column],\n",
    ")\n",
    "\n",
    "train_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=train_start_index,\n",
    "    end_index=train_end_index,\n",
    ")\n",
    "valid_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=valid_start_index,\n",
    "    end_index=valid_end_index,\n",
    ")\n",
    "test_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=test_start_index,\n",
    "    end_index=test_end_index,\n",
    ")\n",
    "\n",
    "tsp = TimeSeriesPreprocessor(\n",
    "    timestamp_column=timestamp_column,\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    scaling=True,\n",
    ")\n",
    "tsp.train(train_data)"
   ]
  },
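  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " The preprocessor summary above lists per-column `mean_` and `scale_` values fitted on the 8640 training rows. Judging by the\n",
    " field names (`with_mean`, `with_std`, `n_samples_seen_`), these look like the fitted attributes of a scikit-learn `StandardScaler`;\n",
    " the sketch below assumes that and applies the same standardization by hand to the `OT` column, using the numbers printed above.\n",
    " The helper names are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Values copied from the printed preprocessor state for the OT column.\n",
    "OT_MEAN = 17.1282616982271\n",
    "OT_SCALE = 9.176491024944335\n",
    "\n",
    "\n",
    "def standardize(x, mean, scale):\n",
    "    # z = (x - mean) / scale, as a standard scaler would compute it.\n",
    "    return (np.asarray(x, dtype=float) - mean) / scale\n",
    "\n",
    "\n",
    "def unstandardize(z, mean, scale):\n",
    "    # Inverse transform back to the original units.\n",
    "    return np.asarray(z, dtype=float) * scale + mean\n",
    "\n",
    "\n",
    "raw = np.array([5.0, 17.1282616982271, 30.0])\n",
    "z = standardize(raw, OT_MEAN, OT_SCALE)\n",
    "restored = unstandardize(z, OT_MEAN, OT_SCALE)\n",
    "print(np.round(z, 3))\n",
    "```\n",
    "\n",
    " Since the model is trained with an MSE loss on these preprocessed values, the `eval_loss` numbers reported in this notebook are\n",
    " most likely in standardized units rather than in the original units of each column."
   ]
  },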
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(train_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "valid_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(valid_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "test_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(test_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")"
   ]
  },
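  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " As a sanity check on dataset sizes, assume that `ForecastDFDataset` produces one sample per time step, as a stride-one sliding\n",
    " window would. This assumption is consistent with the evaluation progress bars (175 batches) and the 50300 total training steps\n",
    " (100 epochs of 503 batches) logged in the cells that follow. The helper functions below are hypothetical.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def num_windows(num_rows, context_length, forecast_horizon):\n",
    "    # One sample per start position whose context + horizon fit inside the split.\n",
    "    return num_rows - context_length - forecast_horizon + 1\n",
    "\n",
    "\n",
    "def num_batches(num_samples, batch_size):\n",
    "    # The last, partial batch is kept.\n",
    "    return math.ceil(num_samples / batch_size)\n",
    "\n",
    "\n",
    "test_rows = 14400 - 11008  # test_end_index - test_start_index\n",
    "train_rows = 8640  # train_end_index, starting at row 0\n",
    "test_samples = num_windows(test_rows, 512, 96)\n",
    "train_samples = num_windows(train_rows, 512, 96)\n",
    "print(test_samples, num_batches(test_samples, 16))  # 2785 175\n",
    "print(train_samples, num_batches(train_samples, 16))  # 8033 503\n",
    "```"
   ]
  },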
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Zero-shot forecasting on ETTH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading pretrained model\n",
      "Done\n"
     ]
    }
   ],
   "source": [
    "print(\"Loading pretrained model\")\n",
    "finetune_forecast_model = PatchTSTForPrediction.from_pretrained(\n",
    "    \"patchtst/electricity/model/pretrain/\", num_input_channels=len(forecast_columns)\n",
    ")\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Doing zero-shot forecasting on target data\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='350' max='175' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [175/175 00:25]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target data zero-shot forecasting result:\n",
      "{'eval_loss': 0.46627211570739746, 'eval_runtime': 7.6504, 'eval_samples_per_second': 364.035, 'eval_steps_per_second': 22.875}\n"
     ]
    }
   ],
   "source": [
    "finetune_forecast_args = TrainingArguments(\n",
    "    output_dir=\"./checkpoint/patchtst/transfer/finetune/output/\",\n",
    "    overwrite_output_dir=True,\n",
    "    learning_rate=0.0001,\n",
    "    num_train_epochs=100,\n",
    "    do_eval=True,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    dataloader_num_workers=num_workers,\n",
    "    report_to=\"tensorboard\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_total_limit=3,\n",
    "    logging_dir=\"./checkpoint/patchtst/transfer/finetune/logs/\",  # Make sure to specify a logging directory\n",
    "    load_best_model_at_end=True,  # Load the best model when training ends\n",
    "    metric_for_best_model=\"eval_loss\",  # Metric to monitor for early stopping\n",
    "    greater_is_better=False,  # For loss\n",
    "    label_names=[\"future_values\"],\n",
    ")\n",
    "\n",
    "# Create a new early stopping callback with faster convergence properties\n",
    "early_stopping_callback = EarlyStoppingCallback(\n",
    "    early_stopping_patience=5,  # Number of epochs with no improvement after which to stop\n",
    "    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement\n",
    ")\n",
    "\n",
    "finetune_forecast_trainer = Trainer(\n",
    "    model=finetune_forecast_model,\n",
    "    args=finetune_forecast_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    ")\n",
    "\n",
    "print(\"\\n\\nDoing zero-shot forecasting on target data\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data zero-shot forecasting result:\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Target data linear probing"
   ]
  },
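  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Linear probing keeps the pretrained backbone fixed and trains only the prediction head. The cell below does this by setting\n",
    " `requires_grad = False` on every parameter of `finetune_forecast_trainer.model.model`, the backbone of `PatchTSTForPrediction`.\n",
    " The same mechanism is shown here on a hypothetical toy model, `ToyForecaster`, whose `model`/`head` attribute names are chosen\n",
    " only to mirror that layout; afterwards only the head's weights remain trainable.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class ToyForecaster(nn.Module):\n",
    "    # Hypothetical stand-in: a small backbone under .model and a linear head.\n",
    "    def __init__(self, d_in=16, d_model=32, horizon=4):\n",
    "        super().__init__()\n",
    "        self.model = nn.Sequential(nn.Linear(d_in, d_model), nn.ReLU())\n",
    "        self.head = nn.Linear(d_model, horizon)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.head(self.model(x))\n",
    "\n",
    "\n",
    "def count_trainable(module):\n",
    "    return sum(p.numel() for p in module.parameters() if p.requires_grad)\n",
    "\n",
    "\n",
    "toy = ToyForecaster()\n",
    "# Linear probing: freeze every backbone parameter, leave the head trainable.\n",
    "for param in toy.model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "print(count_trainable(toy))  # 132 = 32 * 4 weights + 4 biases in the head\n",
    "```"
   ]
  },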
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Linear probing on the target data\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='7545' max='50300' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 7545/50300 04:46 < 27:02, 26.35 it/s, Epoch 15/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.372600</td>\n",
       "      <td>0.684668</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.353300</td>\n",
       "      <td>0.679172</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.347700</td>\n",
       "      <td>0.676667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.345100</td>\n",
       "      <td>0.669841</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.342100</td>\n",
       "      <td>0.667121</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.340500</td>\n",
       "      <td>0.668182</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.339600</td>\n",
       "      <td>0.672160</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.338600</td>\n",
       "      <td>0.664347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.337700</td>\n",
       "      <td>0.672996</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.336800</td>\n",
       "      <td>0.658911</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.336100</td>\n",
       "      <td>0.670552</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.335500</td>\n",
       "      <td>0.662045</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.334500</td>\n",
       "      <td>0.670143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.334300</td>\n",
       "      <td>0.662012</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.333900</td>\n",
       "      <td>0.679170</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='175' max='175' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [175/175 00:06]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target data head/linear probing result:\n",
      "{'eval_loss': 0.36961913108825684, 'eval_runtime': 7.4862, 'eval_samples_per_second': 372.02, 'eval_steps_per_second': 23.376, 'epoch': 15.0}\n"
     ]
    }
   ],
   "source": [
    "# Freeze the backbone of the model\n",
    "for param in finetune_forecast_trainer.model.model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "print(\"\\n\\nLinear probing on the target data\")\n",
    "finetune_forecast_trainer.train()\n",
    "print(\"Evaluating\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data head/linear probing result:\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['patchtst/electricity/model/transfer/ETTh1/preprocessor/preprocessor_config.json']"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/model/linear_probe/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "finetune_forecast_trainer.save_model(save_dir)\n",
    "\n",
    "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/preprocessor/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "tsp.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Target data full finetune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Finetuning on the target data\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3018' max='50300' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 3018/50300 02:58 < 46:43, 16.86 it/s, Epoch 6/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.359700</td>\n",
       "      <td>0.707633</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.327000</td>\n",
       "      <td>0.731749</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.308800</td>\n",
       "      <td>0.789940</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.292800</td>\n",
       "      <td>0.860487</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.278200</td>\n",
       "      <td>0.925161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.263900</td>\n",
       "      <td>0.887657</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='175' max='175' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [175/175 00:06]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target data full finetune result:\n",
      "{'eval_loss': 0.37372443079948425, 'eval_runtime': 7.5282, 'eval_samples_per_second': 369.944, 'eval_steps_per_second': 23.246, 'epoch': 6.0}\n"
     ]
    }
   ],
   "source": [
    "# Reload the model\n",
    "finetune_forecast_model = PatchTSTForPrediction.from_pretrained(\n",
    "    \"patchtst/electricity/model/pretrain/\", num_input_channels=len(forecast_columns)\n",
    ")\n",
    "finetune_forecast_trainer = Trainer(\n",
    "    model=finetune_forecast_model,\n",
    "    args=finetune_forecast_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    ")\n",
    "print(\"\\n\\nFinetuning on the target data\")\n",
    "finetune_forecast_trainer.train()\n",
    "print(\"Evaluating\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data full finetune result:\")\n",
    "print(result)"
   ]
  },
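  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Both target-domain runs end long before 100 epochs: linear probing at epoch 15 and full finetuning at epoch 6. The plain-Python\n",
    " sketch below replays the logged validation losses through the patience rule as we understand `EarlyStoppingCallback` to apply it:\n",
    " the patience counter resets only when the loss beats the best value so far by more than `early_stopping_threshold`, while the best\n",
    " checkpoint is updated on any improvement. It is an approximation for illustration, not the `transformers` source, and\n",
    " `early_stopping_epoch` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "def early_stopping_epoch(val_losses, patience, threshold):\n",
    "    # Returns (stop_epoch, best_epoch), both 1-based.\n",
    "    best, best_epoch, counter = None, None, 0\n",
    "    for epoch, loss in enumerate(val_losses, start=1):\n",
    "        # Reset patience only for an improvement larger than the threshold.\n",
    "        if best is None or best - loss > threshold:\n",
    "            counter = 0\n",
    "        else:\n",
    "            counter += 1\n",
    "        # The best checkpoint is tracked for any improvement, however small.\n",
    "        if best is None or loss < best:\n",
    "            best, best_epoch = loss, epoch\n",
    "        if counter >= patience:\n",
    "            return epoch, best_epoch\n",
    "    return len(val_losses), best_epoch\n",
    "\n",
    "\n",
    "# Validation losses copied from the training tables in this notebook.\n",
    "linear_probe_val = [0.684668, 0.679172, 0.676667, 0.669841, 0.667121, 0.668182, 0.672160, 0.664347,\n",
    "                    0.672996, 0.658911, 0.670552, 0.662045, 0.670143, 0.662012, 0.679170]\n",
    "finetune_val = [0.707633, 0.731749, 0.789940, 0.860487, 0.925161, 0.887657]\n",
    "\n",
    "print(early_stopping_epoch(linear_probe_val, patience=5, threshold=0.001))  # (15, 10)\n",
    "print(early_stopping_epoch(finetune_val, patience=5, threshold=0.001))  # (6, 1)\n",
    "```\n",
    "\n",
    " The replay reproduces both stopping epochs. Because `load_best_model_at_end=True`, the test evaluations should use the best\n",
    " checkpoints, which by this replay are epoch 10 for linear probing and epoch 1 for full finetuning, after which the finetuning\n",
    " validation loss rises while the training loss keeps falling."
   ]
  },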
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = f\"patchtst/electricity/model/transfer/{dataset}/model/fine_tuning/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "finetune_forecast_trainer.save_model(save_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
